{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Practical Session. Variational Autoencoders\n",
    "\n",
    "During this practical session you will implement a vanilla variational autoencoder on MNIST and then a modification of VAE that can be used for semi-supervised learning. To emphasize the probabilistic nature of the models, both implementations will be based on classes for parametric probabilistic distributions that generally follow the design of *tensorflow.distributions*, *torch.distributions*.\n",
    "\n",
    "To complete the notebook you will have to implement several classes for distributions and then use them to construct the stochastic graph for loss function computation and then train the models.\n",
    "\n",
    "# AEs  vs. VAEs\n",
    "\n",
    "Although autoencoders can provide good reconstruction quality, \n",
    "\n",
    "![Autoencoder reconstructions](ae_reconstructions.png)\n",
    "\n",
    "the model has no control over the learned latent representations. For example, an interpolation of latent representaitons of two digit is typically not a latent representation for a digit:\n",
    "\n",
    "![Autoencoder interpolations](ae_interpolations.png)\n",
    "\n",
    "On the other hand, a standard VAE model forces latent representation to fit multivariate Gaussian distribution. As a result, an interpolation of two latent representations is likely to be a latent representation of a digit."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are using google collab install the libraries with \n",
    "# !pip3 install torch\n",
    "# !pip3 install torchvision\n",
    "from torchvision.datasets import MNIST\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import numpy as np\n",
    "import matplotlib.pylab as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the semi-supervised learning task we remove 95% of labels from the training set. In the modified training set the observed labels have a standard one-hot encoding and the unobserved labels are represented by all-zero ten dimensional vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = MNIST('mnist', download=True, train=True)\n",
    "new_train_labels = torch.zeros(60000, 10)\n",
    "observed = np.random.choice(60000, 3000)\n",
    "new_train_labels[observed] = torch.eye(10)[data.train_labels][observed]\n",
    "train_data = TensorDataset(data.train_data.view(-1, 28 * 28).float() / 255,\n",
    "                           new_train_labels)\n",
    "\n",
    "data = MNIST('mnist', download=True, train=False)\n",
    "test_data = TensorDataset(data.test_data.view(-1, 28 * 28).float() / 255,\n",
    "                          data.test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributions for VAE\n",
    "\n",
    "To define the probabilistic model for VAE we need two types of distributions. For latent variables $z$ we need a multivariate normal distribution with diagonal covariance matrix (to put other way, a vector of independent normal random variables). For observarions $x$ we need a vector of independent Bernoulli random variables.\n",
    "\n",
    "All the distributions we need are already implemented in *torch.distributions* and can be easily used in practice. At this practical session, for educational purposes, we suggest to implement it yourself.\n",
    "\n",
    "## Bernoulli random vector\n",
    "\n",
    "It is a good practice to store the distribution parameters as logits $l(x=1)$, because it helps to avoid computation of logarithms of very small values. Logits are converted to probabilities with sigmoid function $p(x=1) = \\frac{1}{1 + \\exp(-l(x = 1))}$.\n",
    "\n",
    "### Class details\n",
    "\n",
    "- *self.logits* is a $N \\times D$-dimensional tensor of logits for each of $D$ pixels of $N$ object in batch.\n",
    "\n",
    "- *log_prob(x)* takes a tensor of observations of size $N \\times D$ (batch size $\\times$ vector dimensionality). Returns an $n$-dimensional vector of logarithms of observation probabilisties for each object in the batch:\n",
    "\n",
    "\\begin{equation}\n",
    "\\sum_{d = 1}^{D} \\left( x_{nd} \\log p(x_{nd} = 1) + (1 - x_{nd}) \\log p(x_{nd} = 0) \\right), \\enspace n=1, \\dots, N\n",
    "\\end{equation}\n",
    "\n",
    "- *sample( )* returns a $N \\times D$ binary vector sampled from the model"
   ]
  },
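  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The numerical argument for storing logits can be made concrete with a scalar sketch. The snippet below is an illustration only, written in plain Python rather than torch; the helper names `softplus` and `bernoulli_log_prob` are hypothetical and not part of the notebook. It relies on the identities $\\log p(x=1) = -\\operatorname{softplus}(-l)$ and $\\log p(x=0) = -\\operatorname{softplus}(l)$, so no probability is ever formed explicitly.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softplus(t):\n",
    "    # log(1 + exp(t)), arranged so that exp() never overflows\n",
    "    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))\n",
    "\n",
    "def bernoulli_log_prob(logits, x):\n",
    "    # sum over components of x * log p(x=1) + (1 - x) * log p(x=0)\n",
    "    return sum(-xi * softplus(-l) - (1 - xi) * softplus(l)\n",
    "               for l, xi in zip(logits, x))\n",
    "\n",
    "# first row of the test case used in this notebook\n",
    "row_logits = [0.26257313, 1.00010365, 1.32164169, -0.60049884, 0.47478581]\n",
    "row_x = [0, 1, 1, 0, 1]\n",
    "print(bernoulli_log_prob(row_logits, row_x))\n",
    "\n",
    "# sigmoid(-800) underflows to 0, so log(sigmoid(-800)) would fail;\n",
    "# the logit form stays finite\n",
    "print(bernoulli_log_prob([-800.0], [1]))\n",
    "```\n",
    "\n",
    "The first printed value should agree with the reference value for the first row in the test below, up to rounding. A torch implementation can follow the same idea with tensor operations."
   ]
  },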
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BernoulliVector():\n",
    "    def __init__(self, logits):\n",
    "        self.logits = logits\n",
    "\n",
    "    def log_prob(self, x):\n",
    "        ##################################\n",
    "        #         implement it           #\n",
    "        ##################################\n",
    "        pass\n",
    "    \n",
    "    def sample(self):\n",
    "        ##################################\n",
    "        #         implement it           #\n",
    "        ##################################\n",
    "        pass\n",
    "        # torch.rand_like"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you can check your implementation with the following test:\n",
    "def test_BernoulliVector():\n",
    "    logits = torch.tensor([[ 0.26257313,  1.00010365,  1.32164169, -0.60049884,  0.47478581],\n",
    "                           [-0.69943423,  0.40572153,  0.91215638,  1.36048238,  0.28434441],\n",
    "                           [ 0.11055949, -0.65058279, -1.74598369,  1.2715774 , -0.60143489]])\n",
    "    bv = BernoulliVector(logits)\n",
    "    # test log_prob()\n",
    "    x = torch.tensor([[0, 1, 1, 0, 1],\n",
    "                      [0, 0, 0, 0, 0],\n",
    "                      [0, 1, 1, 0, 1]], dtype=torch.float32)\n",
    "    log_probs = bv.log_prob(x)\n",
    "    assert(log_probs.shape[0] == 3), 'log_prob() returns wrong shape'\n",
    "    \n",
    "    true_log_probs = np.asarray([-2.3037, -5.0039, -6.2844], dtype=np.float32)\n",
    "    np.testing.assert_allclose(log_probs.numpy(), true_log_probs, atol=1e-4,\n",
    "                               err_msg='log_prob() returns wrong values')\n",
    "    # test sample()\n",
    "    assert(logits.shape == bv.sample().shape), 'sample() returns wrong shape'\n",
    "    \n",
    "    mean = torch.zeros_like(logits)\n",
    "    for i in range(1024):\n",
    "        mean += 1 / 1024 * bv.sample().type(torch.float32)\n",
    "        \n",
    "    np.testing.assert_allclose(nn.functional.sigmoid(logits).numpy(), mean, atol=1e-1,\n",
    "                               err_msg='law of large number seems to be violated by sample()')\n",
    "    \n",
    "    print(\"All fine!\")\n",
    "    \n",
    "test_BernoulliVector()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multivariate Normal Distribution\n",
    "\n",
    "### Class details\n",
    "\n",
    "- *self.loc, self.scale* are $N\\times d$-dimensional tensors that store mean and standard deviation for $d$ components of the latent representation for $N$ object in the batch\n",
    "\n",
    "- *log_prob(z)* takes a $N \\times d$-dimensional tensor and returns log-density for each object in the batch\n",
    "\\begin{equation}\n",
    "\\sum_{d=1}^D \\log \\mathcal{N}(z_{nd} \\mid \\mu_{nd}, \\sigma_{nd}), \\enspace n=1, \\dots, N\n",
    "\\end{equation}\n",
    "- *sample( )* returns a $N \\times d$-dimensional tensor samples from the distribution. **For the reparametrization trick to work properly, the sampling procedure has to be a deterministic function $f$ of a random tensor $\\epsilon$ and the model parameters $\\theta$: $z = f(\\epsilon, \\theta)$**"
   ]
  },
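  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a sketch of the two quantities from the class description in expanded form, treating *self.scale* as a standard deviation. The component index $j$ is introduced here only to avoid a clash with the latent dimensionality $d$:\n",
    "\n",
    "\\begin{align*}\n",
    "\\log \\mathcal N(z_{nj} \\mid \\mu_{nj}, \\sigma_{nj}^2) &= -\\frac{1}{2} \\log (2 \\pi) - \\log \\sigma_{nj} - \\frac{(z_{nj} - \\mu_{nj})^2}{2 \\sigma_{nj}^2} \\\\\n",
    "z_{nj} &= \\mu_{nj} + \\sigma_{nj} \\varepsilon_{nj}, \\enspace \\varepsilon_{nj} \\sim \\mathcal N(0, 1)\n",
    "\\end{align*}\n",
    "\n",
    "The second line is one possible choice of the deterministic function $f$: all the randomness sits in $\\varepsilon$, and the sample is differentiable with respect to the mean and the standard deviation."
   ]
  },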
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to implement\n",
    "class MultivariateNormalDiag():\n",
    "    def __init__(self, loc=None, scale=None):\n",
    "        self.loc = loc\n",
    "        self.scale = scale\n",
    "        \n",
    "    def log_prob(self, z):\n",
    "        ##################################\n",
    "        #         implement it           #\n",
    "        ##################################\n",
    "        pass\n",
    "        \n",
    "    def sample(self):\n",
    "        ##################################\n",
    "        #         implement it           #\n",
    "        ##################################\n",
    "        # torch.randn_like\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_MultivariateNormalDiag():\n",
    "    mean = torch.tensor([[ 0.0619,  1.9728,  0.2092],\n",
    "                         [ 0.3971, -0.1817,  1.1508]])\n",
    "    scale = torch.tensor([[ 0.0619, 1.9728,  0.2092],\n",
    "                          [ 0.3971, 0.1817,  1.1508]])\n",
    "    mnd = MultivariateNormalDiag(mean, scale)\n",
    "    # test log_prob()\n",
    "    x = torch.ones_like(mean)\n",
    "    log_probs = mnd.log_prob(x)\n",
    "    assert(log_probs.shape[0] == 2), 'log_prob() returns wrong shape'\n",
    "    \n",
    "    true_log_probs = np.asarray([-121.1941,  -22.5777], dtype=np.float32)\n",
    "    np.testing.assert_allclose(log_probs.numpy(), true_log_probs, atol=1e-4,\n",
    "                               err_msg='log_prob() returns wrong values')\n",
    "    # test sample()\n",
    "    assert(mean.shape == mnd.sample().shape), 'sample() returns wrong shape'\n",
    "    \n",
    "    est_mean = torch.zeros_like(mean)\n",
    "    for i in range(1024):\n",
    "        est_mean += 1 / 1024 * mnd.sample()\n",
    "        \n",
    "    np.testing.assert_allclose(est_mean, mean, atol=1e-1,\n",
    "                               err_msg='law of large number seems to be violated by sample()')\n",
    "    \n",
    "    print(\"All fine!\")\n",
    "    \n",
    "test_MultivariateNormalDiag()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vanilla VAE\n",
    "\n",
    "## A brief VAE description\n",
    "\n",
    "A variational autoencoder consists of two components. The first is a probabilistic model for observations: \n",
    "\\begin{align}\n",
    "& p(x, z \\mid \\theta) =  p(z) p(x \\mid z, \\theta) \\\\\n",
    "& p(z) = \\mathcal N(z \\mid 0, I) \\\\\n",
    "& p(x \\mid z, \\theta) = \\prod_{i = 1}^D p_i(z, \\theta)^{x_i} (1 - p_i(z, \\theta))^{1 - x_i}.\n",
    "\\end{align}\n",
    "The second is a variational approximation, used to compute the lower bound on marginal likelihood (our loss function)\n",
    "\\begin{equation}\n",
    "q(z \\mid x, \\phi) = \\mathcal N(z \\mid \\mu(x, \\phi), \\operatorname{diag}(\\sigma^2(x, \\phi)))\n",
    "\\end{equation}\n",
    "\n",
    "The lower bound for one object $x$ from the minibatch has form\n",
    "$$ \\mathcal L(x, \\theta, \\phi) = \\mathbb E_{q(z \\mid x, \\phi)} \\left[ \\log p(x \\mid z, \\phi) + \\log p(z) - \\log q(z \\mid x, \\theta) \\right] $$\n",
    "\n",
    "In practice, we can't compute the above expectation. Therefore, we approximate it with the following one-sample Monte-Carlo estimate:\n",
    "\\begin{align*}\n",
    "\\log p(x \\mid z_0, \\phi) + \\log p(z_0) - \\log q(z_0 \\mid x, \\theta) \\\\\n",
    "z_0 = \\mu(x, \\phi) + \\sigma^2(x, \\phi)^T \\varepsilon_0 \\\\\n",
    "\\varepsilon_0 \\sim \\mathcal N(0, I)\n",
    "\\end{align*}\n",
    "**Note that this choice of the Monte-Carlo estimate for expectation is crucial and is typically reffered to as reparametrization trick.** For more details see [Auto-encoding Variational Bayes](https://arxiv.org/abs/1312.6114) paper.\n",
    "\n",
    "Finally, to train the model we average the lower bound values over the minibatch and then maximize the average with gradient ascent:\n",
    "\n",
    "$$ \\frac{1}{N} \\sum_{n=1}^N \\log p(x_n \\mid z_n, \\phi) + \\log p(z_n) - \\log q(z_n \\mid x_n, \\theta) \\rightarrow \\max_{\\theta, \\phi} $$\n",
    "\n",
    "## Encoder and decoder\n",
    "\n",
    "$q(z\\mid x, \\theta)$ is usually referred to as encoder and $p(x \\mid z, \\phi)$ is usually reffered to as decoder. To parametrize these distributions we introduce two neural networks:\n",
    "\n",
    "- *enc* takes $x$ as input and return $2 \\times d$-dimensional vector to parametrize mean and standard deviation of $q(z \\mid x, \\theta)$\n",
    "- *dec* takes a latent representation $z$ and returns the logits of distribution $p(x \\mid z, \\phi)$.\n",
    "\n",
    "The computational graph has a simple structure of autoencoder. The only difference is that now it uses a stochastic variable $\\varepsilon$:\n",
    "\n",
    "![vae](vae.png)"
   ]
  },
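  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reparametrization described above can be checked numerically on a toy problem. The sketch below is an illustration under stated assumptions, not part of the model: it uses a scalar $z = \\mu + \\sigma \\varepsilon$, the test function $f(z) = z^2$, and a hypothetical helper `reparam_grad_estimate`. Since $\\mathbb E[z^2] = \\mu^2 + \\sigma^2$, the exact derivative with respect to $\\mu$ is $2\\mu$, and the reparametrized estimate averages $\\partial f / \\partial \\mu = 2 z$ over samples of $\\varepsilon$.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def reparam_grad_estimate(mu, sigma, num_samples=100000, seed=0):\n",
    "    # Monte-Carlo estimate of d/dmu E[z^2] for z = mu + sigma * eps.\n",
    "    # z is a deterministic function of (mu, sigma, eps), so the derivative\n",
    "    # moves inside the expectation: d/dmu z^2 = 2 * z.\n",
    "    rng = random.Random(seed)\n",
    "    total = 0.0\n",
    "    for _ in range(num_samples):\n",
    "        eps = rng.gauss(0.0, 1.0)\n",
    "        z = mu + sigma * eps\n",
    "        total += 2.0 * z\n",
    "    return total / num_samples\n",
    "\n",
    "estimate = reparam_grad_estimate(mu=1.5, sigma=0.5)\n",
    "exact = 2 * 1.5  # derivative of mu^2 + sigma^2 with respect to mu\n",
    "print(estimate, exact)\n",
    "```\n",
    "\n",
    "In the notebook the same effect is obtained automatically: if *sample( )* is written as a differentiable function of the encoder outputs, backpropagation through the sampled $z$ yields this kind of gradient estimate."
   ]
  },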
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d, nh, D = 32, 100, 28 * 28\n",
    "\n",
    "enc = nn.Sequential(\n",
    "    nn.Linear(D, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, 2 * d)) # mind the non-linearities at the final layer\n",
    "\n",
    "dec = nn.Sequential(\n",
    "    nn.Linear(d, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, D)) # <-----------------------------------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss function\n",
    "\n",
    "Implement the loss function for the variational autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(x, encoder, decoder):\n",
    "    ##################################\n",
    "    #         implement it           #\n",
    "    ##################################\n",
    "    # return an average of elbo terms over the minibatch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "\n",
    "def train_model(encoder, decoder, batch_size=100, num_epochs=3, learning_rate=1e-3):\n",
    "    gd = optim.Adam(chain(encoder.parameters(), decoder.parameters()), lr=learning_rate)\n",
    "    dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
    "    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)\n",
    "    train_losses = []\n",
    "    test_results = []\n",
    "    for _ in range(num_epochs):\n",
    "        for i, (batch, _) in enumerate(dataloader):\n",
    "            total = len(dataloader)\n",
    "            loss_value = loss(batch, encoder, decoder)\n",
    "            (-loss_value).backward()\n",
    "            train_losses.append(loss_value.cpu().item())\n",
    "            if (i + 1) % 10 == 0:\n",
    "                print('\\rTrain loss:', train_losses[-1],\n",
    "                      'Batch', i + 1, 'of', total, ' ' * 10, end='', flush=True)\n",
    "            gd.step()\n",
    "            gd.zero_grad()\n",
    "        test_elbo = 0.\n",
    "        for i, (batch, _) in enumerate(test_dataloader):\n",
    "            batch_elbo = loss(batch, encoder, decoder)\n",
    "            test_elbo += (batch_elbo - test_elbo) / (i + 1)\n",
    "        print('\\nTest loss after an epoch: {}'.format(test_elbo))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# my implementation has test loss = -110.31\n",
    "train_model(enc, dec, num_epochs=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualisations\n",
    "\n",
    "- How do reconstruction compare to reconstructions of autoencoder?\n",
    "- Interpolations?\n",
    "- Is the latent space regularly covered? \n",
    "- Is there any dependence between T-SNE encoding and the digit label?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_reconstructions(encoder, decoder):\n",
    "    batch = test_data[np.random.choice(10000, 25)][0]\n",
    "    rec = nn.functional.sigmoid(decoder(encoder(batch)[:, :d]))\n",
    "    rec = rec.view(-1, 28, 28).data\n",
    "    batch = batch.view(-1, 28, 28).numpy()\n",
    "    \n",
    "    fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(14, 7),\n",
    "                             subplot_kw={'xticks': [], 'yticks': []})\n",
    "    for i in range(25):\n",
    "        axes[i % 5, 2 * (i // 5)].imshow(batch[i], cmap='gray')\n",
    "        axes[i % 5, 2 * (i // 5) + 1].imshow(rec[i], cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plot_reconstructions(enc, dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_interpolations(encoder, decoder):\n",
    "    batch = encoder(test_data[np.random.choice(10000, 10)][0])\n",
    "    z_0 = batch[:5, :d].view(5, 1, d)\n",
    "    z_1 = batch[5:, :d].view(5, 1, d)\n",
    "    alpha = torch.tensor(np.linspace(0., 1., 10), dtype=torch.float32)\n",
    "    alpha = alpha.view(1, 10, 1)\n",
    "    interpolations_z = (z_0 * alpha + z_1 * (1 - alpha))\n",
    "    interpolations_z = interpolations_z.view(50, d)\n",
    "    interpolations_x = nn.functional.sigmoid(decoder(interpolations_z))\n",
    "    interpolations_x = interpolations_x.view(5, 10, 28, 28).data.numpy()\n",
    "    \n",
    "    fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(14, 7),\n",
    "                             subplot_kw={'xticks': [], 'yticks': []})\n",
    "    for i in range(50):\n",
    "        axes[i // 10, i % 10].imshow(interpolations_x[i // 10, i % 10], cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_interpolations(enc, dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_tsne(objects, labels):\n",
    "    from sklearn.manifold import TSNE\n",
    "    embeddings = TSNE(n_components=2).fit_transform(objects)\n",
    "    plt.figure(figsize=(8, 8))\n",
    "    for k in range(10):\n",
    "        embeddings_for_k = embeddings[labels == k]\n",
    "        plt.scatter(embeddings_for_k[:, 0], embeddings_for_k[:, 1],\n",
    "                    label='{}'.format(k))\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_variables = enc(test_data[:1000][0])[:, :d]\n",
    "latent_variables = latent_variables.data.numpy()\n",
    "labels = test_data[:1000][1].numpy()\n",
    "\n",
    "plot_tsne(latent_variables, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semi-supervised VAE model\n",
    "\n",
    "This part of notebook is inspired by [\"Semi-supervised Learning with\n",
    "Deep Generative Models\"](https://arxiv.org/pdf/1406.5298.pdf). We will also use this model to illustrate the Gumbel-Softmax distribution."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the semi-supervised setting, the generative model is a little more complicated. In particular, it incorporates a new variable $y$ that represents the digits class.\n",
    "\n",
    "\\begin{align*}\n",
    "& p(x, y, z) = p(x \\mid y, z) p(z) p(y) \\\\\n",
    "& p(y) = Cat(y \\mid \\pi_0), \\pi_0 = (1/10, \\dots, 1/10) \\\\\n",
    "& p(z) = \\mathcal N(z \\mid 0, I) \\\\\n",
    "& p(x \\mid y, z) = \\prod_{i=1}^D p_i(y, z)^{x_i} (1 - p_i(y, z))^{1 - x_i}\n",
    "\\end{align*}\n",
    "\n",
    "Typically, whenever we train a model with partial observations, we interpret unobserved variables as latent variables and marginalize over them. In this case, the loss function splits into two terms: one for observed variables (we denote the set of indices of observed labels $P$), another for unobserved.\n",
    "\n",
    "\\begin{equation}\n",
    "L(X, y) = \\sum_{i \\notin P} \\log p(x_i) + \\sum_{i \\in P} \\log p(x_i, y_i)\n",
    "\\end{equation}\n",
    "\n",
    "Again, we can't compute the exact values of marginal likelihoods and resort to variational lower bound on likelihood. To compute lower bounds we define the following variational approximation:\n",
    "\n",
    "\\begin{align*}\n",
    "& q(y, z \\mid x) = q(y \\mid x) q(z \\mid y, x)\\\\\n",
    "& \\\\\n",
    "& q(y \\mid x) = Cat(y \\mid \\pi(x))\\\\\n",
    "& q(z \\mid y, x) = \\mathcal N(z \\mid \\mu_\\phi(x, y), \\operatorname{diag}\\sigma^2_\\phi(y, x))\n",
    "\\end{align*}\n",
    "\n",
    "### ELBO for observed variables\n",
    "\n",
    "Similiar to VAE:\n",
    "\n",
    "\\begin{equation}\n",
    "\\log p(x, y) = \\log \\mathbb E_{p(z)} p(x, y \\mid z) \\geq \\mathbb E_{q(z \\mid y, x)} \\log \\frac{p(x, y \\mid z) p(z)}{q(z \\mid y, x)}\n",
    "\\end{equation}\n",
    "\n",
    "### ELBO for unobserved variables\n",
    "\n",
    "\\begin{equation}\n",
    "\\log p(x) = \\log \\mathbb E_{p(y)} \\mathbb E_{p(z \\mid y)} \\log p(x\\mid z, y)\\geq \\mathbb E_{q(y \\mid x)} \\mathbb E_{q(z \\mid y, x)} \\log \\frac{p(x, y \\mid z) p(z)}{q(z \\mid y, x) q(y \\mid x)}\n",
    "\\end{equation}\n",
    "\n",
    "### The final objective\n",
    "\n",
    "\\begin{equation}\n",
    "\\mathcal L(X, y) = \\sum_{i \\in P} \\mathbb E_{q(z_i \\mid y_i, x_i)} \\log \\frac{p(x_i, y_i \\mid z_i) p(z_i)}{q(z_i \\mid y_i, x_i)} + \\sum_{i \\notin P} \\mathbb E_{q(y_i \\mid x_i)} \\mathbb E_{q(z_i \\mid y_i, x_i)} \\log \\frac{p(x_i, y_i \\mid z_i) p(z_i)}{q(z_i \\mid y_i, x_i) q(y_i \\mid x_i)}\n",
    "\\end{equation}\n",
    "\n",
    "Again, we will use reparametrized Monte-Carlo estimates to approximate expectation over $z$. To approximate expectaion over discrete variables $y$ we will use Gumbel-Softmax trick.\n",
    "\n",
    "# Important practical aspect\n",
    "\n",
    "ELBO maximization does not lead to any semantics in latent variables $y$. \n",
    "\n",
    "We are going to restrict variational approximations $q(y \\mid x)$ to the ones that correctly classify observation $x$ on fully-observed variables $(x_i, y_i)$. As in the original paper, we will add a cross-entropy regularizer to the objective with weight $\\alpha$:\n",
    "\n",
    "\\begin{equation}\n",
    "\\frac{1}{|P|}\\sum_{i \\in P} y_i^T \\log q(y \\mid x).\n",
    "\\end{equation}"
   ]
  },
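  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of how the two expectations in the final objective turn into per-object estimates, assume one sample $\\tilde z_i \\sim q(z \\mid y, x_i)$ and, for unlabeled objects, one relaxed sample $\\tilde y_i \\sim q(y \\mid x_i)$. The tilde notation is introduced here for illustration only. Using $p(x, y \\mid z) = p(x \\mid y, z) p(y)$, the terms are\n",
    "\n",
    "\\begin{align*}\n",
    "i \\in P: \\quad & \\log p(x_i \\mid y_i, \\tilde z_i) + \\log p(y_i) + \\log p(\\tilde z_i) - \\log q(\\tilde z_i \\mid y_i, x_i) \\\\\n",
    "i \\notin P: \\quad & \\log p(x_i \\mid \\tilde y_i, \\tilde z_i) + \\log p(\\tilde y_i) + \\log p(\\tilde z_i) - \\log q(\\tilde z_i \\mid \\tilde y_i, x_i) - \\log q(\\tilde y_i \\mid x_i)\n",
    "\\end{align*}\n",
    "\n",
    "For labeled objects $\\log p(y_i) = \\log (1/10)$ is a constant, so dropping it shifts the value of the objective without changing its gradients."
   ]
  },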
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RelaxedOneHotCategorical\n",
    "\n",
    "In the probabilistic model defined above we are going to replace categorical prior $p(y)$ and categorical variational approximation $q(y | x)$ with Gumbel-Softmax distribution. The distribution class is implemented below:\n",
    "\n",
    "For more details see [Categorical Reparameterization with Gumbel-Softmax](https://arxiv.org/abs/1611.01144)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelaxedOneHotCategorical():\n",
    "    def __init__(self, logits, temperature):\n",
    "        self.k = torch.tensor([logits.shape[1]], dtype=torch.float32)\n",
    "        self.logits = logits\n",
    "        self.temperature = torch.tensor([temperature])\n",
    "\n",
    "    def log_prob(self, x):\n",
    "        log_Z = (torch.lgamma(self.k) + (self.k - 1) * self.temperature.log())\n",
    "        log_prob_unnormalized = (nn.functional.log_softmax(\n",
    "            self.logits - self.temperature * x.log(), dim=1) - x.log()).sum(1)\n",
    "        return log_prob_unnormalized + log_Z\n",
    "    \n",
    "    def sample(self):\n",
    "        gumbel = -(-torch.rand_like(self.logits).log()).log()\n",
    "        sample = nn.functional.softmax((self.logits + gumbel) / self.temperature, dim=1)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An illustration for Gumbel-Softmax\n",
    "\n",
    "- Temperature allows for smooth interpolation between one-hot categorical distribution with low temperature and a $(1/K, \\dots, 1/K)$ vector with high temperatures\n",
    "- The exact computation of $\\mathbb E_{q(y|x)} f(y)$ requires computation of $f(y)$ for ten possible labels $y=0, \\dots, 9$. On the other hand, with Gumbel-Softmax relaxation only one sample $y \\sim q(y | x)$ is enough. Therefore, Gumbel-Softmax gives almost a ten-fold training speed increase."
   ]
  },
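  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling step behind the relaxation can also be written out without tensors. The following is a minimal plain-Python sketch, assuming the standard construction (Gumbel noise added to the logits, then a temperature-scaled softmax); the names `gumbel_softmax_sample` and `example_logits` are hypothetical. It also checks empirically that the largest component of a relaxed sample is distributed according to $\\operatorname{softmax}$ of the logits.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def gumbel_softmax_sample(logits, temperature, rng):\n",
    "    # Gumbel(0, 1) noise via -log(-log(u)); u is clamped away from 0 and 1\n",
    "    noisy = []\n",
    "    for l in logits:\n",
    "        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)\n",
    "        noisy.append((l - math.log(-math.log(u))) / temperature)\n",
    "    # numerically stable softmax over the perturbed, rescaled logits\n",
    "    m = max(noisy)\n",
    "    weights = [math.exp(v - m) for v in noisy]\n",
    "    total = sum(weights)\n",
    "    return [w / total for w in weights]\n",
    "\n",
    "rng = random.Random(0)\n",
    "example_logits = [0.5, -1.0, 2.0]\n",
    "sample = gumbel_softmax_sample(example_logits, temperature=0.5, rng=rng)\n",
    "\n",
    "# frequency of each index being the largest component of a relaxed sample\n",
    "counts = [0] * len(example_logits)\n",
    "num_draws = 20000\n",
    "for _ in range(num_draws):\n",
    "    s = gumbel_softmax_sample(example_logits, temperature=0.5, rng=rng)\n",
    "    counts[s.index(max(s))] += 1\n",
    "freqs = [c / num_draws for c in counts]\n",
    "\n",
    "# class probabilities of the underlying categorical distribution\n",
    "norm = sum(math.exp(l) for l in example_logits)\n",
    "probs = [math.exp(l) / norm for l in example_logits]\n",
    "print(freqs, probs)\n",
    "```\n",
    "\n",
    "The temperature changes how peaked each sample is, but not which component is the largest, which is why the argmax frequencies match the categorical probabilities at any temperature."
   ]
  },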
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.cm as cm\n",
    "n_classes = 4\n",
    "logits = torch.randn(1, n_classes)\n",
    "temperatures = [0.1, 0.5, 1., 5., 10.]\n",
    "M = 128 # number of samples used to approximate distribution mean\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=len(temperatures), figsize=(14, 6),\n",
    "                         subplot_kw={'xticks': range(n_classes),\n",
    "                                     'yticks': [0., 0.5, 1.]})\n",
    "axes[0, 0].set_ylabel('Expectation')\n",
    "axes[1, 0].set_ylabel('Gumbel Softmax Sample')\n",
    "\n",
    "for n, t in enumerate(temperatures):\n",
    "    dist = RelaxedOneHotCategorical(logits, t)\n",
    "    mean = torch.zeros_like(logits)\n",
    "    for _ in range(M):\n",
    "        mean += dist.sample() / M\n",
    "    sample = dist.sample()\n",
    "    \n",
    "    axes[0, n].set_title('T = {}'.format(t))\n",
    "    axes[0, n].set_ylim((0, 1.1))\n",
    "    axes[1, n].set_ylim((0, 1.1))\n",
    "    axes[0, n].bar(np.arange(n_classes), mean.numpy().reshape(n_classes),\n",
    "                   color=cm.plasma(0.75 * t / max(temperatures)))\n",
    "    axes[1, n].bar(np.arange(n_classes), sample.numpy().reshape(n_classes),\n",
    "                   color=cm.plasma(0.75 * t / max(temperatures)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SS-VAE implementation\n",
    "\n",
    "The computational graph for observed labels has the following structure:\n",
    "\n",
    "![computational graph ss vae xy](ss_vae_xy.png)\n",
    "\n",
    "The computational graph for unobserved lables has the following structure:\n",
    "\n",
    "![computational graph ss vae xy](ss_vae_x.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_classes, d, nh, D = 10, 32, 500, 28 * 28\n",
    "\n",
    "yz_dec = nn.Sequential(\n",
    "    nn.Linear(n_classes + d, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, D))\n",
    "\n",
    "y_enc = nn.Sequential(\n",
    "    nn.Linear(D, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, n_classes))\n",
    "\n",
    "z_enc = nn.Sequential(\n",
    "    nn.Linear(n_classes + D, nh),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(nh, 2 * d)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(x, y, y_encoder, z_encoder, decoder, T=0.6, alpha=32.):#, verbose=False):\n",
    "    ##################################\n",
    "    #         implement it           #\n",
    "    ##################################\n",
    "    \n",
    "    ################################################################################\n",
    "    # NOTE:                                                                        #\n",
    "    # hyperparameter alpha was tuned for the implementation that computed  mean of #\n",
    "    # elbo terms and sum of cross-entropy terms over observed datapoints in batch  #\n",
    "    ################################################################################\n",
    "    \n",
    "    # sample y from q(y | x)\n",
    "    \n",
    "    # sample z from q(z | x, y)\n",
    "    \n",
    "    # compute the evidence lower bound for obervsed and unobserved variables\n",
    "    \n",
    "    # compute the cross_entropy regularizer with weight alpha\n",
    "    \n",
    "    # return the sum of two losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "\n",
    "def train_model(y_encoder, z_encoder, decoder, batch_size=100, num_epochs=3, learning_rate=1e-3):\n",
    "    gd = optim.Adam(chain(y_encoder.parameters(),\n",
    "                          z_encoder.parameters(),\n",
    "                          decoder.parameters()), lr=learning_rate)\n",
    "    dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
    "    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)\n",
    "    train_losses = []\n",
    "    for _ in range(num_epochs):\n",
    "        for i, (x, y) in enumerate(dataloader):\n",
    "            total = len(dataloader)\n",
    "            loss_value = loss(x, y, y_encoder, z_encoder, decoder)\n",
    "            (-loss_value).backward()\n",
    "            train_losses.append(loss_value.cpu().item())\n",
    "            if (i + 1) % 10 == 0:\n",
    "                print('\\rTrain loss:', train_losses[-1],\n",
    "                      'Batch', i + 1, 'of', total, ' ' * 10, end='', flush=True)\n",
    "            gd.step()\n",
    "            gd.zero_grad()\n",
    "        loss_value = 0.\n",
    "        accuracy = 0.\n",
    "        for i, (x, y) in enumerate(test_dataloader):\n",
    "            total = len(test_dataloader)\n",
    "            loss_value += loss(x, torch.zeros((y.shape[0], 10)), y_encoder, z_encoder, decoder).item()\n",
    "            accuracy += (torch.argmax(y_encoder(x), 1) == y).double().mean().item()\n",
    "        print('Test loss: {}\\t Test accuracy: {}'.format(loss_value / total, accuracy / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Neural networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# my implementation omitted log p(y) for observed variables. it has\n",
    "# test loss -106.79\n",
    "# test accuracy 0.95\n",
    "train_model(y_enc, z_enc, yz_dec, num_epochs=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations\n",
    "\n",
    "Generate 10 images for each label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_samples_with_fixed_classes(dec):\n",
    "    decoder_input = torch.cat((torch.eye(10).repeat(10, 1), torch.randn(100, d)), 1)\n",
    "    images = nn.functional.sigmoid(dec(decoder_input)).view(100, 28, 28).data.numpy()\n",
    "    \n",
    "    fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(14, 14),\n",
    "                             subplot_kw={'xticks': [], 'yticks': []})\n",
    "    for i in range(10):\n",
    "        axes[0, i].set_title('{}'.format(i))\n",
    "    \n",
    "    for i in range(100):\n",
    "        axes[int(i / 10), i % 10].imshow(images[i], cmap='gray')\n",
    "        \n",
    "plot_samples_with_fixed_classes(yz_dec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"Style-transfer\"\n",
    "\n",
    "Here we infer latent representation $z$ of a given digit $x$ and then generate from $p(x | z, y)$ for different $y$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_digits_with_fixed_style(z_enc, y_enc, dec):\n",
    "    indices = np.random.choice(10000, 10)\n",
    "    x, y = test_data[indices][0], torch.eye(10)[test_data[indices][1]]\n",
    "    z = z_enc(torch.cat((x, y), 1))[:, :d]\n",
    "\n",
    "    # generate digits\n",
    "    images = []\n",
    "    for i in range(10):\n",
    "        digit_encodings = torch.eye(10)[i, :].expand(10, 10)\n",
    "        images.append(nn.functional.sigmoid(dec(torch.cat((digit_encodings, z), 1)).view(10, 28, 28)).data.numpy())\n",
    "        \n",
    "    x = x.view(10, 28, 28).numpy()\n",
    "\n",
    "    # plot\n",
    "    fig, axes = plt.subplots(nrows=10, ncols=11, figsize=(14, 14),\n",
    "                             subplot_kw={'xticks': [], 'yticks': []})\n",
    "    \n",
    "    axes[0, 0].set_title('example')\n",
    "    for i in range(10):\n",
    "        axes[0, i + 1].set_title('{}'.format(i))\n",
    "        axes[i, 0].imshow(x[i], cmap='gray')\n",
    "        for j in range(10):\n",
    "            axes[i, j + 1].imshow(images[j][i], cmap='gray')\n",
    "            \n",
    "plot_all_digits_with_fixed_style(z_enc, y_enc, yz_dec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T-SNE for SS-VAE\n",
    "\n",
    "Do you notice any difference from T-SNE for vanilla VAE? How can you interpret the results?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# T-SNE for q(z | x, y) mean\n",
    "labels = test_data[:1000][1].numpy()\n",
    "encoder_input = torch.cat((test_data[:1000][0], torch.eye(10)[labels]), 1)\n",
    "latent_variables = z_enc(encoder_input)[:, :d]\n",
    "latent_variables = latent_variables.data.numpy()\n",
    "\n",
    "plot_tsne(latent_variables, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# T-SNE for q(y | x) logits\n",
    "labels = test_data[:1000][1].numpy()\n",
    "latent_variables = y_enc(test_data[:1000][0])\n",
    "latent_variables = latent_variables.data.numpy()\n",
    "\n",
    "plot_tsne(latent_variables, labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
